{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Batch Processing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This script helps you run your motion planner in batch to compute solutions for different scenarios. It is composed of a configuration file (.yaml) and the code below. Before you run this code, make sure the configuration file is set correctly. Here is an explanation of each parameter in the configuration file:\n",
    "* input_path: the input directory of CommonRoad scenarios that you intend to solve.\n",
    "* output_path: the output directory of the solution files.\n",
    "* overwrite: the flag that determines whether to overwrite existing solution files.\n",
    "* timeout: timeout time for your motion planner, unit in seconds\n",
    "* trajectory_planner_path: input directory where the module containing the function to execute your motion planner is located\n",
    "* trajectory_planner_module_name: name of the module that contains the function to execute your motion planner\n",
    "* trajectory_planner_function_name: name of the function that executes your motion planner\n",
    "* default: the parameters specified under this will be applied to all scenarios. If you wish to specify a different parameter for specific scenarios, simply copy the section and replace 'default' with the id of your scenario.\n",
    "* vehicle_model: model of the vehicle, its value could be PM, KS, ST and MB.\n",
    "* vehicle_type: type of the vehicle, its value could be FORD_ESCORT, BMW_320i and VW_VANAGON.\n",
    "* cost_function: identifier of cost function. Please refer to [Cost Functions](https://gitlab.lrz.de/tum-cps/commonroad-cost-functions/blob/master/costFunctions_commonRoad.pdf) for more information."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "import multiprocessing\n",
    "import os\n",
    "import pathlib\n",
    "import sys\n",
    "import time\n",
    "import traceback\n",
    "import warnings\n",
    "from queue import Empty as QueueEmpty\n",
    "\n",
    "import yaml\n",
    "from commonroad.common.file_reader import CommonRoadFileReader\n",
    "from commonroad.common.solution_writer import CommonRoadSolutionWriter, VehicleModel, VehicleType, CostFunction\n",
    "\n",
    "def parse_vehicle_model(model):\n",
    "    if model == 'PM':\n",
    "        cr_model = VehicleModel.PM\n",
    "    elif model == 'ST':\n",
    "        cr_model = VehicleModel.ST\n",
    "    elif model == 'KS':\n",
    "        cr_model = VehicleModel.KS\n",
    "    elif model == 'MB':\n",
    "        cr_model = VehicleModel.MB\n",
    "    else:\n",
    "        raise ValueError('Selected vehicle model is not valid: {}.'.format(model))\n",
    "    return cr_model\n",
    "\n",
    "\n",
    "def parse_vehicle_type(type):\n",
    "    if type == 'FORD_ESCORT':\n",
    "        cr_type = VehicleType.FORD_ESCORT\n",
    "        cr_type_id = 1\n",
    "    elif type == 'BMW_320i':\n",
    "        cr_type = VehicleType.BMW_320i\n",
    "        cr_type_id = 2\n",
    "    elif type == 'VW_VANAGON':\n",
    "        cr_type = VehicleType.VW_VANAGON\n",
    "        cr_type_id = 3\n",
    "    else:\n",
    "        raise ValueError('Selected vehicle type is not valid: {}.'.format(type))\n",
    "        \n",
    "    return cr_type, cr_type_id\n",
    "\n",
    "\n",
    "def parse_cost_function(cost):\n",
    "    if cost == 'JB1':\n",
    "        cr_cost = CostFunction.JB1\n",
    "    elif cost == 'SA1':\n",
    "        cr_cost = CostFunction.SA1\n",
    "    elif cost == 'WX1':\n",
    "        cr_cost = CostFunction.WX1\n",
    "    elif cost == 'SM1':\n",
    "        cr_cost = CostFunction.SM1\n",
    "    elif cost == 'SM2':\n",
    "        cr_cost = CostFunction.SM2\n",
    "    elif cost == 'SM3':\n",
    "        cr_cost = CostFunction.SM3\n",
    "    else:\n",
    "        raise ValueError('Selected cost function is not valid: {}.'.format(cost))\n",
    "    return cr_cost\n",
    "\n",
    "\n",
    "def call_trajectory_planner(queue, function, scenario, planning_problem_set, vehicle_type_id):\n",
    "    queue.put(function(scenario, planning_problem_set, vehicle_type_id))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# open config file\n",
    "with open('batch_processing_config.yaml', 'r') as stream:\n",
    "    try:\n",
    "        # safe_load() only builds plain Python objects; yaml.load() without an explicit Loader is\n",
    "        # deprecated since PyYAML 5.1 and raises a TypeError in PyYAML 6\n",
    "        settings = yaml.safe_load(stream)\n",
    "    except yaml.YAMLError as exc:\n",
    "        print(exc)\n",
    "        # stop here: carrying on would only fail later with a NameError on `settings`\n",
    "        raise\n",
    "\n",
    "# get planning wrapper function\n",
    "# Make the planner directory importable. abspath() resolves a relative path against the current working\n",
    "# directory and leaves an absolute path alone. The plain string concatenation used before is kept as a\n",
    "# second candidate so that configurations written for it (path starting with a separator) keep working.\n",
    "# Re-running this cell does not add the same entry twice.\n",
    "planner_path = settings['trajectory_planner_path']\n",
    "for planner_dir in (os.path.abspath(planner_path), os.getcwd() + os.path.dirname(planner_path)):\n",
    "    if os.path.isdir(planner_dir) and planner_dir not in sys.path:\n",
    "        sys.path.append(planner_dir)\n",
    "\n",
    "# import_module() returns the named module itself, also for dotted names (__import__ would return\n",
    "# the top-level package)\n",
    "module = importlib.import_module(settings['trajectory_planner_module_name'])\n",
    "function = getattr(module, settings['trajectory_planner_function_name'])\n",
    "time_timeout = settings['timeout']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Start Processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the scenario files up front so that the progress counter only counts files that are actually\n",
    "# processed; sorted() makes the order reproducible (os.listdir() returns entries in arbitrary order).\n",
    "scenario_files = sorted(name for name in os.listdir(settings['input_path']) if name.endswith('.xml'))\n",
    "num_files = len(scenario_files)\n",
    "\n",
    "print(\"Total number of files to be processed: {}\".format(num_files))\n",
    "print(\"Timeout setting: {} seconds\\n\".format(time_timeout))\n",
    "\n",
    "# create path for solutions (once; it is the same for every scenario)\n",
    "pathlib.Path(settings['output_path']).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# seconds between two looks at the result queue while the planner is running\n",
    "poll_interval = 0.1\n",
    "# seconds a planner process gets to exit on its own after it has delivered its result\n",
    "exit_grace = 1.0\n",
    "\n",
    "# iterate through scenarios\n",
    "for count_processed, filename in enumerate(scenario_files, start=1):\n",
    "    print(\"File No. {} / {}\".format(count_processed, num_files))\n",
    "\n",
    "    fullname = os.path.join(settings['input_path'], filename)\n",
    "\n",
    "    print(\"Started processing scenario {}\".format(filename))\n",
    "    scenario, planning_problem_set = CommonRoadFileReader(fullname).open()\n",
    "\n",
    "    # get settings for each scenario: the section named after the scenario id if there is one,\n",
    "    # otherwise the 'default' section\n",
    "    if scenario.benchmark_id in settings:\n",
    "        scenario_settings = settings[scenario.benchmark_id]\n",
    "    else:\n",
    "        scenario_settings = settings['default']\n",
    "    vehicle_model = parse_vehicle_model(scenario_settings['vehicle_model'])\n",
    "    vehicle_type, vehicle_type_id = parse_vehicle_type(scenario_settings['vehicle_type'])\n",
    "    cost_function = parse_cost_function(scenario_settings['cost_function'])\n",
    "\n",
    "    result_queue = multiprocessing.Queue()\n",
    "    # create process, pass in required arguments\n",
    "    p = multiprocessing.Process(target=call_trajectory_planner, name=\"trajectory_planner\",\n",
    "                                args=(result_queue, function, scenario, planning_problem_set, vehicle_type_id))\n",
    "    # start planning\n",
    "    p.start()\n",
    "\n",
    "    # Wait for the result or give up after the timeout. The queue is read BEFORE the process is joined:\n",
    "    # a child that has put a large object on a multiprocessing.Queue does not exit until the object has\n",
    "    # been consumed, so joining first can misreport a finished planner as timed out.\n",
    "    deadline = time.monotonic() + time_timeout\n",
    "    result_received = False\n",
    "    solution_trajectories = {}\n",
    "    while not result_received:\n",
    "        # sample liveness BEFORE polling: a child that had already exited normally at this point has\n",
    "        # flushed whatever it put on the queue, so an empty poll then really means \"no result\"\n",
    "        planner_running = p.is_alive()\n",
    "        try:\n",
    "            solution_trajectories = result_queue.get(timeout=poll_interval)\n",
    "            result_received = True\n",
    "        except QueueEmpty:\n",
    "            if not planner_running or time.monotonic() >= deadline:\n",
    "                break\n",
    "\n",
    "    if result_received:\n",
    "        print(\"Planning finished.\")\n",
    "        p.join(timeout=exit_grace)\n",
    "    elif p.is_alive():\n",
    "        print(\"===> Trajectory planner timeout.\")\n",
    "    else:\n",
    "        # the planner process ended without delivering anything; the scenario counts as unsolved\n",
    "        print(\"===> Trajectory planner exited without returning a result.\")\n",
    "    if p.is_alive():\n",
    "        p.terminate()\n",
    "    p.join()\n",
    "\n",
    "    if solution_trajectories is None:\n",
    "        # a planner that returns nothing has solved nothing\n",
    "        solution_trajectories = {}\n",
    "\n",
    "    cr_solution_writer = CommonRoadSolutionWriter(settings['output_path'],\n",
    "                                                  scenario.benchmark_id,\n",
    "                                                  scenario.dt,\n",
    "                                                  vehicle_type,\n",
    "                                                  vehicle_model,\n",
    "                                                  cost_function)\n",
    "\n",
    "    # inspect whether all planning problems are solved\n",
    "    for planning_problem_id in planning_problem_set.planning_problem_dict:\n",
    "        if planning_problem_id not in solution_trajectories:\n",
    "            print('Solution for planning problem with ID={} is not provided for scenario {}. Solution skipped.'.format(\n",
    "                planning_problem_id, filename))\n",
    "            break\n",
    "        cr_solution_writer.add_solution_trajectory(\n",
    "            solution_trajectories[planning_problem_id], planning_problem_id)\n",
    "    else:\n",
    "        # only reached when the loop was not left via break, i.e. every planning problem has a solution\n",
    "        cr_solution_writer.write_to_file(overwrite=settings['overwrite'])\n",
    "\n",
    "    print(\"=========================================================\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (cr37)",
   "language": "python",
   "name": "cr37"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
